import torch
import os
import sys
import argparse
import json
import time
import random
import numpy as np
from tqdm import tqdm
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
from str_utils import de_md_logits_processor_for_llama3_1, flaming_tokens
import str_utils

################
# Configurations
################
def _str2bool(value):
    """Parse a CLI boolean value ("true"/"false", "yes"/"no", "1"/"0").

    Replaces the ``type=bool`` argparse pitfall: ``bool("False")`` is True,
    so every non-empty string used to parse as True.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def get_args(argv=None):
    """Build the CLI parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. None (the default) parses
            sys.argv[1:], so existing callers are unaffected.

    Returns:
        argparse.Namespace with experiment, generation, and system settings.
    """
    # Experiment Settings
    parser = argparse.ArgumentParser(description="Instruction Generation Manager.")
    parser.add_argument("--model_path", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct",
                        help="We will support more models in the future.")

    # Generation Parameters
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--n", type=int, default=200, help="Number of samples to generate for one time.")
    parser.add_argument("--repeat", type=int, default=None, help="Number of times to repeat the instruction generation. Only available when total prompts is not specified.")
    parser.add_argument("--total_prompts", type=int, default=1000, help="Total number of prompts to generate. If specified, repeat will be ignored.")
    parser.add_argument("--max_tokens", type=int, default=2048)
    parser.add_argument("--max_model_len", type=int, default=4096)

    # Generation Settings
    # _str2bool instead of type=bool so "--early_stopping False" actually
    # disables the flag (with type=bool any non-empty value meant True).
    parser.add_argument("--early_stopping", type=_str2bool, default=True, help="Stop generation when the \n is generated.")
    parser.add_argument("--disable_early_stopping", action="store_false", dest="early_stopping", help="Disable early stopping.")
    parser.add_argument("--system_prompt", action="store_true", help="Enable system prompt for extracting the input.")
    parser.add_argument("--sanitize", action="store_true", help="Sanitize the generated instructions. Only available for Gemma and Llama-3 models.")
    parser.add_argument("--logits_processor", action="store_true", help="Enable logits processor for the generation.")
    parser.add_argument("--flaming_tokens", action="store_true", help="Enable flaming initial tokens (increase temperature) for more diverse generation.")
    parser.add_argument("--control_tasks", type=str, default=None, choices=[None, "translation", "code", "math"],  help="Control tasks for the generation. Currently only available for some models.")
    parser.add_argument("--shuffle", type=_str2bool, default=True, help="Shuffle the outputs generated by vllm.")
    parser.add_argument("--skip_special_tokens", type=_str2bool, default=True)

    # System Settings
    parser.add_argument('--engine', default="vllm", type=str, choices=["vllm", "hf"])
    parser.add_argument("--device", type=str, default="0")
    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16"])
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs to use for tensor parallelism. Only used for Llama 70B models.")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.95)
    parser.add_argument("--swap_space", type=float, default=2.0)
    parser.add_argument("--checkpoint_every", type=int, default=100, help="Save checkpoint every n repeats.")
    parser.add_argument("--output_folder", type=str, default="../data")
    parser.add_argument("--job_name", type=str, default=None, help="Job Name. Get from the script.")
    # NOTE: default is evaluated when get_args() is called, not per --help render.
    parser.add_argument("--timestamp", type=int, default=int(time.time()), help="Timestamp for the job. Also used as the random seed.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed.")

    return parser.parse_args(argv)

# Main function to control workflow
def _build_record(args, pre_query_template, instruction, index):
    """Assemble one output record for a generated instruction.

    Deduplicates the previously copy-pasted sanitize / no-sanitize dicts while
    preserving the original JSON key order. When --sanitize is set, the raw
    instruction is kept alongside the post-processed one.
    """
    record = {
        "id": index,
        "pre_query_template": f"{pre_query_template}",
    }
    if args.sanitize:
        sanitized_instruction, class_num = str_utils.instruction_post_process(instruction, args.model_path)
        record["raw_instruction"] = instruction
        record["instruction"] = sanitized_instruction
        record["instruction_sanitize_class_num"] = class_num
    else:
        record["instruction"] = instruction
    record.update({
        "response": None,
        "created": int(time.time()),
        "gen_input_configs": {
            "temperature": args.temperature,
            "top_p": args.top_p,
            "input_generator": f"{args.model_path}",
            "seed": args.seed,
        },
        "gen_response_configs": None,
    })
    return record


def main():
    """Run the end-to-end instruction-generation workflow.

    Steps: parse CLI args, seed RNGs, resolve the output path, load the
    generation engine (vLLM or HF transformers), read the model's pre-query
    template and stop tokens from ../configs/model_configs.json, generate
    args.repeat batches of args.n instructions, and write all records to a
    JSON file with periodic checkpoints.
    """
    args = get_args()
    print(f"Instruction Generation Manager. Arguments: {args}") # For logging

    # Sanitization post-processing rules only exist for Gemma / Llama-3.
    if args.sanitize:
        if not ("gemma" in args.model_path.lower() or "llama-3" in args.model_path.lower()):
            raise ValueError("Sanitization is only supported for Gemma and Llama-3 models.")

    # Derive whichever of (total_prompts, repeat) was not supplied.
    if args.total_prompts is None:
        if args.repeat is None:
            raise ValueError("Either total prompts or repeat should be specified.")
        args.total_prompts = args.repeat * args.n
    else:
        # If total prompts is specified, repeat is recomputed (rounded up),
        # so slightly more than total_prompts records may be produced.
        args.repeat = int(np.ceil(args.total_prompts / args.n))

    # Seed NumPy and PyTorch (CPU and all CUDA devices) for reproducibility.
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Resolve the output path and ensure the folder exists. BUG FIX: the
    # --job_name subfolder was never created before, so the json.dump at the
    # end crashed when it did not already exist.
    output_filename = f"Magpie_{args.model_path.split('/')[-1]}_{args.total_prompts}_{args.timestamp}_ins.json"
    if not args.job_name:
        output_folder = args.output_folder
    else:
        output_folder = f"{args.output_folder}/{args.job_name}"
    os.makedirs(output_folder, exist_ok=True)
    output_dir = f"{output_folder}/{output_filename}"

    # Restrict visible GPUs. NOTE(review): this runs after the torch.cuda
    # seeding above; if CUDA is already initialized the mask may not apply —
    # confirm the intended device selection on multi-GPU hosts.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    # Instantiate the chosen generation engine.
    if args.engine == "vllm":
        llm = LLM(model=args.model_path,
                dtype=args.dtype,
                trust_remote_code=True,
                gpu_memory_utilization=args.gpu_memory_utilization,
                max_model_len=args.max_model_len,
                swap_space=args.swap_space,
                tensor_parallel_size=args.tensor_parallel_size,
                # Fall back to the job timestamp so vLLM still gets a
                # deterministic seed when --seed is omitted.
                seed=args.seed if args.seed is not None else args.timestamp,
                enable_prefix_caching=True)
    elif args.engine == "hf":
        # Load the model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(args.model_path)
        model = AutoModelForCausalLM.from_pretrained(
            args.model_path,
            device_map={'':torch.cuda.current_device()},
            torch_dtype=torch.bfloat16 if args.dtype == "bfloat16" else torch.float16
        )

    # Load the per-model prompt template and stop-token configuration.
    with open("../configs/model_configs.json", "r", encoding="utf-8") as f:
        model_configs = json.load(f)
    model_config = model_configs[args.model_path]
    if args.control_tasks:
        pre_query_template = model_config[f"pre_query_template_{args.control_tasks}"]
        # BUG FIX: missing f-prefix printed the placeholder literally.
        print(f"Control task: {args.control_tasks}")
    elif args.system_prompt:
        pre_query_template = model_config["pre_query_template_with_system_prompt"]
        print("System prompt enabled. Warning: The system prompt may degrade the performance.")
    else:
        pre_query_template = model_config["pre_query_template"]
    stop_tokens = model_config["stop_tokens"]
    stop_tokens_assistant = model_config["stop_tokens_assistant"]
    stop_tokens += stop_tokens_assistant
    stop_token_ids = model_config["stop_token_ids"]

    # Process early stopping. We found that sometimes LLM will generate
    # responses immediately after the \n token.
    if args.early_stopping:
        stop_tokens.append("\n")

    print(f"Pre-query template: {pre_query_template}")
    print(f"Stop tokens: {stop_tokens}")
    print(f"Stop token ids: {stop_token_ids}")

    # The two logits-processor options are mutually exclusive.
    if args.logits_processor and args.flaming_tokens:
        raise ValueError("Cannot enable both logits processor and flaming tokens")

    if args.logits_processor and "llama-3.1" in args.model_path.lower():
        logits_processor = de_md_logits_processor_for_llama3_1
        print(f"Logits processor applied: {logits_processor}")
    elif args.flaming_tokens:
        logits_processor = flaming_tokens
        print(f"Logits processor applied: {logits_processor}")
    else:
        logits_processor = None

    # Sampling parameters shared by every vLLM generation round.
    sampling_params = SamplingParams(
        n=args.n,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        skip_special_tokens=args.skip_special_tokens,
        stop=stop_tokens,
        stop_token_ids=stop_token_ids,
        logits_processors=[logits_processor] if logits_processor else None
    )

    ################
    # Generate outputs
    ################
    results = []
    for rounds in tqdm(range(args.repeat)):
        if args.engine == "vllm":
            # One prompt per call; SamplingParams.n yields args.n samples.
            output = llm.generate(pre_query_template, sampling_params)
            output_list = output[0].outputs
            if args.shuffle:
                random.shuffle(output_list)

        elif args.engine == "hf":
            # Renamed from `input` to avoid shadowing the builtin.
            input_ids = tokenizer.encode(pre_query_template, add_special_tokens=False, return_tensors="pt").to(torch.cuda.current_device())
            # Gemma-2 bug, so we cannot set num_return_sequences > 1.
            # Instead, we repeat the input n times.
            inputs = input_ids.repeat(args.n, 1).to(torch.cuda.current_device())
            # NOTE(review): max_length counts prompt tokens too, unlike vLLM's
            # max_tokens (new tokens only) — confirm this asymmetry is intended.
            output = model.generate(inputs,
                                    tokenizer=tokenizer,
                                    do_sample=True,
                                    temperature=args.temperature,
                                    top_p=args.top_p,
                                    max_length=args.max_tokens,
                                    num_return_sequences=1,
                                    )
            # Strip the prompt tokens from each sequence before decoding.
            output_list = tokenizer.batch_decode([output[i][len(inputs[0]):] for i in range(args.n)])
            # Truncate each completion at the EARLIEST stop token. BUG FIX:
            # the old loop re-sliced the original string for every matching
            # stop token, so the cut from the LAST match in stop_tokens order
            # won, not the first occurrence in the text.
            for i, completion in enumerate(output_list):
                cut = min((completion.index(t) for t in stop_tokens if t in completion), default=None)
                if cut is not None:
                    output_list[i] = completion[:cut]

        # Collect one record per completion (vLLM yields objects with .text,
        # HF yields plain strings).
        for i, completion in enumerate(output_list):
            if args.engine == "vllm":
                instruction = completion.text.strip()
            elif args.engine == "hf":
                instruction = completion.strip()
            results.append(_build_record(args, pre_query_template, instruction, rounds * args.n + i))

        # Save the checkpoints every args.checkpoint_every rounds
        if rounds % args.checkpoint_every == 0:
            with open(output_dir, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2)
            print(f"Checkpoint saved. Total prompts: {len(results)}")

    # Save the final results
    with open(output_dir, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"Instruction generated from {args.model_path}. Total prompts: {len(results)}")

# Script entry point: run the full instruction-generation workflow when
# executed directly (not on import).
if __name__ == "__main__":
    main()
